# Copyright (c) 2015 Red Hat, Inc.
# Copyright (c) 2015 SUSE Linux Products GmbH
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import random

import eventlet
import mock
from oslo_config import cfg
from oslo_log import log as logging
from oslo_utils import uuidutils

from neutron.agent.common import config as agent_config
from neutron.agent.common import ovs_lib
from neutron.agent.l2.extensions import manager as ext_manager
from neutron.agent.linux import interface
from neutron.agent.linux import polling
from neutron.agent.linux import utils as agent_utils
from neutron.common import config as common_config
from neutron.common import constants as n_const
from neutron.common import utils
from neutron.plugins.common import constants as p_const
from neutron.plugins.ml2.drivers.openvswitch.agent.common import config \
    as ovs_config
from neutron.plugins.ml2.drivers.openvswitch.agent.common import constants
from neutron.plugins.ml2.drivers.openvswitch.agent.openflow.ovs_ofctl \
    import br_int
from neutron.plugins.ml2.drivers.openvswitch.agent.openflow.ovs_ofctl \
    import br_phys
from neutron.plugins.ml2.drivers.openvswitch.agent.openflow.ovs_ofctl \
    import br_tun
from neutron.plugins.ml2.drivers.openvswitch.agent import ovs_neutron_agent \
    as ovs_agent
from neutron.tests.common import net_helpers
from neutron.tests.functional.agent.linux import base

LOG = logging.getLogger(__name__)


class OVSAgentTestFramework(base.BaseOVSLinuxTestCase):

    def setUp(self):
        """Mock the plugin RPC classes and prepare per-test OVS resources.

        Picks random integration/tunnel bridge names, derives the patch port
        names from them, then builds the agent configuration, the OVS
        interface driver and a network namespace fixture.
        """
        super(OVSAgentTestFramework, self).setUp()
        # Both plugin-facing RPC classes are replaced by mocks.
        rpc_targets = (
            'neutron.plugins.ml2.drivers.openvswitch.agent.'
            'ovs_neutron_agent.OVSPluginApi',
            'neutron.agent.rpc.PluginReportStateAPI',
        )
        for target in rpc_targets:
            mock.patch(target).start()

        max_len = n_const.DEVICE_NAME_MAX_LEN
        self.br_int = base.get_rand_name(max_len, prefix='br-int')
        self.br_tun = base.get_rand_name(max_len, prefix='br-tun')

        # Patch port names reuse the tail of the bridge names, starting at
        # the index left over once the "-patch-xxx" suffix length is removed
        # from the maximum device name length.
        tail_start = max_len - len("-patch-tun")
        self.patch_tun = "%s-patch-tun" % self.br_int[tail_start:]
        self.patch_int = "%s-patch-int" % self.br_tun[tail_start:]

        self.ovs = ovs_lib.BaseOVS()
        self.config = self._configure_agent()
        self.driver = interface.OVSInterfaceDriver(self.config)
        namespace_fixture = self.useFixture(net_helpers.NamespaceFixture())
        self.namespace = namespace_fixture.name

    def _get_config_opts(self):
        """Return a fresh ConfigOpts with the agent-related options registered.

        Registers the core and interface options without a group, the OVS
        driver options under the "OVS" and "AGENT" groups, and then lets the
        interface-driver, agent-state and L2 extension helpers register
        theirs.
        """
        conf = cfg.ConfigOpts()
        for ungrouped_opts in (common_config.core_opts, interface.OPTS):
            conf.register_opts(ungrouped_opts)
        grouped_opts = (
            (ovs_config.ovs_opts, "OVS"),
            (ovs_config.agent_opts, "AGENT"),
        )
        for opts, group in grouped_opts:
            conf.register_opts(opts, group)
        registration_helpers = (
            agent_config.register_interface_driver_opts_helper,
            agent_config.register_agent_state_opts_helper,
            ext_manager.register_opts,
        )
        for register in registration_helpers:
            register(conf)
        return conf

    def _configure_agent(self):
        config = self._get_config_opts()
        config.set_override(
            'interface_driver',
            'neutron.agent.linux.interface.OVSInterfaceDriver')
        config.set_override('integration_bridge', self.br_int, "OVS")
        config.set_override('ovs_integration_bridge', self.br_int)
        config.set_override('tunnel_bridge', self.br_tun, "OVS")
        config.set_override('int_peer_patch_port', self.patch_tun, "OVS")
        config.set_override('tun_peer_patch_port', self.patch_int, "OVS")
        config.set_override('host', 'ovs-agent')
        return config

    def _bridge_classes(self):
        """Return the ovs-ofctl bridge classes keyed by bridge role."""
        return dict(
            br_int=br_int.OVSIntegrationBridge,
            br_phys=br_phys.OVSPhysicalBridge,
            br_tun=br_tun.OVSTunnelBridge,
        )

    def create_agent(self, create_tunnels=True):
        """Create an OVSNeutronAgent bound to this test's bridges.

        :param create_tunnels: when true the agent is configured with the
            VXLAN tunnel type and the tunnel bridge is scheduled for
            deletion; otherwise ``tunnel_types`` is overridden with None.
        :returns: the new agent, with ``sg_agent`` replaced by a mock.
        """
        if create_tunnels:
            tunnel_types = [p_const.TYPE_VXLAN]
        else:
            tunnel_types = None
        bridge_mappings = ['physnet:%s' % self.br_int]
        self.config.set_override('tunnel_types', tunnel_types, "AGENT")
        self.config.set_override('polling_interval', 1, "AGENT")
        self.config.set_override('prevent_arp_spoofing', False, "AGENT")
        self.config.set_override('local_ip', '192.168.10.1', "OVS")
        self.config.set_override('bridge_mappings', bridge_mappings, "OVS")
        # Schedule the bridge deletions *before* building the agent, so a
        # constructor that fails part-way through cannot leave the randomly
        # named bridges behind. Cleanups run in reverse order, so the tunnel
        # bridge is still removed before the integration bridge.
        # NOTE(review): assumes the agent constructor is what creates these
        # bridges and that delete_bridge is harmless for a bridge that was
        # never created -- confirm against ovs_lib.
        self.addCleanup(self.ovs.delete_bridge, self.br_int)
        if tunnel_types:
            self.addCleanup(self.ovs.delete_bridge, self.br_tun)
        agent = ovs_agent.OVSNeutronAgent(self._bridge_classes(),
                                          self.config)
        agent.sg_agent = mock.Mock()
        return agent

    def start_agent(self, agent, unplug_ports=None):
        """Run the agent's rpc_loop in a green thread until test cleanup.

        :param agent: agent whose ``rpc_loop`` is spawned
        :param unplug_ports: optional port dicts forwarded to
            ``setup_agent_rpc_mocks``; None is treated as an empty list
        """
        ports_to_unplug = [] if unplug_ports is None else unplug_ports
        self.setup_agent_rpc_mocks(agent, ports_to_unplug)

        poller = polling.InterfacePollingMinimizer()
        # The stop cleanup is registered before start() is called.
        self.addCleanup(poller.stop)
        poller.start()
        agent_utils.wait_until_true(poller._monitor.is_active)

        # The status check is mocked to always report OVS_NORMAL.
        agent.check_ovs_status = mock.Mock(return_value=constants.OVS_NORMAL)
        loop_thread = eventlet.spawn(agent.rpc_loop, poller)

        def _stop_rpc_loop():
            # Clear the loop flag, then wait for the green thread to finish.
            agent.run_daemon_loop = False
            loop_thread.wait()

        self.addCleanup(_stop_rpc_loop)

    def _create_test_port_dict(self, **kwargs):
        """Return a dict describing a randomly generated test port.

        The dict carries a fresh UUID under 'id', a random MAC address based
        on the fa:16:3e prefix, a single fixed IP in 10.x.y.z with each
        random octet in [3, 254], and a random 'vif_name' sized and prefixed
        according to the interface driver.

        :param kwargs: optional extra fields merged into the result; a key
            that matches a generated field replaces it. Accepted because
            ``create_test_ports`` forwards its keyword arguments here.
        """
        octets = tuple(random.randint(3, 254) for _ in range(3))
        port = {
            'id': uuidutils.generate_uuid(),
            'mac_address': utils.get_random_mac(
                'fa:16:3e:00:00:00'.split(':')),
            'fixed_ips': [{'ip_address': '10.%d.%d.%d' % octets}],
            'vif_name': base.get_rand_name(
                self.driver.DEV_NAME_LEN, self.driver.DEV_NAME_PREFIX),
        }
        port.update(kwargs)
        return port

    def _create_test_network_dict(self):
        """Return a network dict with freshly generated id and tenant_id."""
        network_id = uuidutils.generate_uuid()
        tenant_id = uuidutils.generate_uuid()
        return dict(id=network_id, tenant_id=tenant_id)

    def _plug_ports(self, network, ports, agent, ip_len=24):
        for port in ports:
            self.driver.plug(
                network.get('id'), port.get('id'), port.get('vif_name'),
                port.get('mac_address'),
                agent.int_br.br_name, namespace=self.namespace)
            ip_cidrs = ["%s/%s" % (port.get('fixed_ips')[0][
                'ip_address'], ip_len)]
            self.driver.init_l3(port.get('vif_name'), ip_cidrs,
                                namespace=self.namespace)

    def _unplug_ports(self, ports, agent):
        for port in ports:
            self.driver.unplug(
                port.get('vif_name'), agent.int_br.br_name, self.namespace)

    def _get_device_details(self, port, network):
        dev = {'device': port['id'],
               'port_id': port['id'],
               'network_id': network['id'],
               'network_type': 'vlan',
               'physical_network': 'physnet',
               'segmentation_id': 1,
               'fixed_ips': port['fixed_ips'],
               'device_owner': 'compute',
               'port_security_enabled': True,
               'security_groups': ['default'],
               'admin_state_up': True}
        return dev

    def assert_bridge(self, br, exists=True):
        self.assertEqual(exists, self.ovs.bridge_exists(br))

    def assert_patch_ports(self, agent):
        """Wait until the two patch ports reference each other as peers.

        :param agent: agent whose integration bridge object is used to read
            the 'options' column of each patch Interface
        """
        def peer_matches(port, expected_peer):
            # Build a zero-argument predicate suitable for wait_until_true.
            def check():
                options = agent.int_br.db_get_val(
                    'Interface', port, 'options', check_error=True)
                return options == {'peer': expected_peer}
            return check

        agent_utils.wait_until_true(
            peer_matches(self.patch_int, self.patch_tun))
        agent_utils.wait_until_true(
            peer_matches(self.patch_tun, self.patch_int))

    def assert_bridge_ports(self):
        for port in [self.patch_tun, self.patch_int]:
            self.assertTrue(self.ovs.port_exists(port))

    def assert_vlan_tags(self, ports, agent):
        for port in ports:
            res = agent.int_br.db_get_val('Port', port.get('vif_name'), 'tag')
            self.assertTrue(res)

    def _expected_plugin_rpc_call(self, call, expected_devices, is_up=True):
        """Helper to check expected rpc call are received

        :param call: The call to check
        :param expected_devices: The device for which call is expected
        :param is_up: True if expected_devices are devices that are set up,
               False if expected_devices are devices that are set down
        """
        if is_up:
            rpc_devices = [
                dev for args in call.call_args_list for dev in args[0][1]]
        else:
            rpc_devices = [
                dev for args in call.call_args_list for dev in args[0][2]]
        return not (set(expected_devices) - set(rpc_devices))

    def create_test_ports(self, amount=3, **kwargs):
        ports = []
        for x in range(amount):
            ports.append(self._create_test_port_dict(**kwargs))
        return ports

    def _mock_update_device(self, context, devices_up, devices_down, agent_id,
                            host=None):
        dev_up = []
        dev_down = []
        for port in self.ports:
            if devices_up and port['id'] in devices_up:
                dev_up.append(port['id'])
            if devices_down and port['id'] in devices_down:
                dev_down.append({'device': port['id'], 'exists': True})
        return {'devices_up': dev_up,
                'failed_devices_up': [],
                'devices_down': dev_down,
                'failed_devices_down': []}

    def setup_agent_rpc_mocks(self, agent, unplug_ports):
        def mock_device_details(context, devices, agent_id, host=None):
            details = []
            for port in self.ports:
                if port['id'] in devices:
                    dev = self._get_device_details(
                        port, self.network)
                    details.append(dev)
            ports_to_unplug = [x for x in unplug_ports if x['id'] in devices]
            if ports_to_unplug:
                self._unplug_ports(ports_to_unplug, self.agent)
            return {'devices': details, 'failed_devices': []}

        (agent.plugin_rpc.get_devices_details_list_and_failed_devices.
            side_effect) = mock_device_details
        agent.plugin_rpc.update_device_list.side_effect = (
            self._mock_update_device)

    def _prepare_resync_trigger(self, agent):
        def mock_device_raise_exception(context, devices_up, devices_down,
                                        agent_id, host=None):
            agent.plugin_rpc.update_device_list.side_effect = (
                self._mock_update_device)
            raise Exception('Exception to trigger resync')

        self.agent.plugin_rpc.update_device_list.side_effect = (
            mock_device_raise_exception)

    def wait_until_ports_state(self, ports, up, timeout=60):
        """Block until all ports were reported via update_device_list.

        :param ports: port dicts whose ids are expected in the RPC calls
        :param up: True to wait for the ports being reported up, False for
            them being reported down
        :param timeout: passed through to ``wait_until_true``
        """
        expected_ids = [port['id'] for port in ports]

        def ports_reported():
            update_rpc = self.agent.plugin_rpc.update_device_list
            return self._expected_plugin_rpc_call(
                update_rpc, expected_ids, up)

        agent_utils.wait_until_true(ports_reported, timeout=timeout)

    def setup_agent_and_ports(self, port_dicts, create_tunnels=True,
                              trigger_resync=False):
        self.agent = self.create_agent(create_tunnels=create_tunnels)
        self.start_agent(self.agent)
        self.network = self._create_test_network_dict()
        self.ports = port_dicts
        if trigger_resync:
            self._prepare_resync_trigger(self.agent)
        self._plug_ports(self.network, self.ports, self.agent)
